/*
 * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package jdk.graal.compiler.vector.replacements.vectorapi.nodes;

import static jdk.graal.compiler.replacements.nodes.MacroNode.MacroParams;

import java.util.List;

import org.graalvm.collections.EconomicMap;

import jdk.graal.compiler.core.common.type.ArithmeticOpTable;
import jdk.graal.compiler.core.common.type.Stamp;
import jdk.graal.compiler.debug.GraalError;
import jdk.graal.compiler.graph.Node;
import jdk.graal.compiler.graph.NodeClass;
import jdk.graal.compiler.graph.NodeMap;
import jdk.graal.compiler.nodeinfo.NodeInfo;
import jdk.graal.compiler.nodes.FrameState;
import jdk.graal.compiler.nodes.NodeView;
import jdk.graal.compiler.nodes.ValueNode;
import jdk.graal.compiler.nodes.calc.BinaryArithmeticNode;
import jdk.graal.compiler.nodes.spi.Canonicalizable;
import jdk.graal.compiler.nodes.spi.CanonicalizerTool;
import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.vector.architecture.VectorArchitecture;
import jdk.graal.compiler.vector.nodes.simd.SimdCutNode;
import jdk.graal.compiler.vector.nodes.simd.SimdStamp;
import jdk.vm.ci.code.CodeUtil;

/**
 * Intrinsic node for the {@code VectorSupport.reductionCoerced} method. This operation applies a
 * binary arithmetic operation across the elements of a vector {@code v} and reinterprets the scalar
 * result bits as a {@code long}:
 *
 * <pre>{@code
 * tmp = OP(v.0, OP(v.1, OP(..., v.n)));
 * result = reinterpretAsLong(tmp)
 * }</pre>
 *
 * <p>
 * A mask is not currently supported by Graal. The binary operation is identified by an integer
 * opcode which we map to the corresponding Graal operation.
 *
 * <p>
 * In the Vector API, the order of the operations in the reduction "is intentionally not defined".
 * As floating-point arithmetic is not associative, the exact sum or product of a floating-point
 * vector will therefore depend on choices made by the compiler (we make no effort to match the code
 * generated by C2), the vector length, and the intermediate operations supported by the target
 * architecture. Thus, in general, a floating-point result produced by the intrinsic will not be
 * equal to the value computed by the interpreter on the original Java method.
 */
@NodeInfo(nameTemplate = "VectorAPIReductionCoerced {p#op/s}")
public class VectorAPIReductionCoercedNode extends VectorAPISinkNode implements Canonicalizable {

    public static final NodeClass<VectorAPIReductionCoercedNode> TYPE = NodeClass.create(VectorAPIReductionCoercedNode.class);

    /**
     * Stamp of the vector to reduce, or {@code null} while it has not been determined. The node
     * cannot be expanded as long as this is {@code null}.
     */
    private final SimdStamp inputVectorStamp;

    /**
     * The binary operation used to combine the vector elements, or {@code null} while it has not
     * been determined. The node cannot be expanded as long as this is {@code null}.
     */
    private final ArithmeticOpTable.BinaryOp<?> op;

    /*
     * Indices into the macro argument list for relevant input values. Argument index 2 is not read
     * by this node.
     */
    private static final int OPRID_ARG_INDEX = 0;
    private static final int VCLASS_ARG_INDEX = 1;
    private static final int ECLASS_ARG_INDEX = 3;
    private static final int LENGTH_ARG_INDEX = 4;
    private static final int VALUE_ARG_INDEX = 5;
    private static final int MASK_ARG_INDEX = 6;

    /**
     * Creates the node with already computed (possibly {@code null}) stamp and operation.
     *
     * @param macroParams the macro parameters forwarded to the superclass
     * @param inputVectorStamp the stamp of the input vector, or {@code null} if unknown
     * @param op the reduction operation, or {@code null} if unknown
     * @param stateAfter the frame state to install on this node, may be {@code null}
     */
    protected VectorAPIReductionCoercedNode(MacroParams macroParams, SimdStamp inputVectorStamp, ArithmeticOpTable.BinaryOp<?> op, FrameState stateAfter) {
        super(TYPE, macroParams);
        this.inputVectorStamp = inputVectorStamp;
        this.op = op;
        this.stateAfter = stateAfter;
    }

    /**
     * Factory used when the intrinsic is first built. It makes a first attempt at deriving the
     * vector stamp and the operation from the macro arguments; the node is created without a frame
     * state.
     *
     * @param macroParams the macro parameters, including the call arguments
     * @param providers the providers handed to the stamp and operation lookup helpers
     * @return a new node, whose stamp and operation may still be unknown
     */
    public static VectorAPIReductionCoercedNode create(MacroParams macroParams, CoreProviders providers) {
        SimdStamp inputVectorStamp = improveVectorStamp(null, macroParams.arguments, VCLASS_ARG_INDEX, ECLASS_ARG_INDEX, LENGTH_ARG_INDEX, providers);
        ArithmeticOpTable.BinaryOp<?> op = improveBinaryOp(null, macroParams.arguments, OPRID_ARG_INDEX, inputVectorStamp, providers);
        return new VectorAPIReductionCoercedNode(macroParams, inputVectorStamp, op, null);
    }

    /**
     * Returns the argument holding the vector to be reduced.
     */
    public ValueNode inputVector() {
        return getArgument(VALUE_ARG_INDEX);
    }

    /**
     * Returns the vector-valued inputs of this node; the only one is {@link #inputVector()}.
     */
    @Override
    public Iterable<ValueNode> vectorInputs() {
        return List.of(inputVector());
    }

    /**
     * If the vector stamp or the operation is still unknown, tries again to derive them from the
     * current arguments and returns a replacement node when either one changed. Otherwise returns
     * this node unchanged.
     */
    @Override
    public Node canonical(CanonicalizerTool tool) {
        if (inputVectorStamp != null && op != null) {
            /* Nothing to improve. */
            return this;
        }

        ValueNode[] args = toArgumentArray();
        SimdStamp newVectorStamp = improveVectorStamp(inputVectorStamp, args, VCLASS_ARG_INDEX, ECLASS_ARG_INDEX, LENGTH_ARG_INDEX, tool);
        ArithmeticOpTable.BinaryOp<?> newOp = improveBinaryOp(op, args, OPRID_ARG_INDEX, newVectorStamp, tool);
        boolean stampChanged = newVectorStamp != inputVectorStamp;
        boolean opChanged = newOp != null && !newOp.equals(op);
        if (stampChanged || opChanged) {
            return new VectorAPIReductionCoercedNode(copyParams(), newVectorStamp, newOp, stateAfter());
        }
        return this;
    }

    /**
     * The node can be expanded once both the vector stamp and the operation are known and the mask
     * argument is the null constant.
     */
    @Override
    public boolean canExpand(VectorArchitecture vectorArch, EconomicMap<ValueNode, Stamp> simdStamps) {
        if (inputVectorStamp == null || op == null) {
            return false;
        }
        if (!getArgument(MASK_ARG_INDEX).isNullConstant()) {
            /* Masked reductions are not supported. */
            return false;
        }
        /*
         * No need to check the vector architecture, an expansion is always possible. In the worst
         * case we must extract each element of the input vector one by one and do the reduction in
         * scalar code.
         */
        return true;
    }

    /**
     * Builds the reduction of the already expanded input vector: first a binary cascade that
     * repeatedly combines the lower and upper halves of the current vector, then, if more than one
     * element remains, a linear cascade over the remaining elements. The final value is passed
     * through {@code reinterpretAsLong}.
     *
     * @param vectorArch the architecture queried for supported vector lengths
     * @param expanded the map from original nodes to their expanded counterparts
     * @return the node representing the reduction result
     */
    @Override
    public ValueNode expand(VectorArchitecture vectorArch, NodeMap<ValueNode> expanded) {
        ValueNode currentVector = expanded.get(inputVector());
        SimdStamp simdStamp = (SimdStamp) currentVector.stamp(NodeView.DEFAULT);
        Stamp elementStamp = simdStamp.getComponent(0);
        boolean isInteger = elementStamp.isIntegerStamp();
        int currentLength = simdStamp.getVectorLength();
        GraalError.guarantee(CodeUtil.isPowerOf2(currentLength), "expected power of 2 vector length: %s", simdStamp);

        /*
         * Build a binary cascade like:
         *
         * <pre>
         *   SimdCut(0, length/2)  SimdCut(length/2, length/2)
         *                      \  /
         *                       op
         *                      /  \
         *   SimdCut(0, length/4)  SimdCut(length/4, length/4)
         *                      \  /
         *                       op
         *                       ...
         * </pre>
         *
         * as long as the half of the current vector length is a legal vector length for the binary
         * operation.
         *
         * The loop stops explicitly once a single element is left. Without this bound the
         * architecture would be queried with a maximum length of 0, and termination would depend on
         * the answer to that degenerate query.
         */
        while (currentLength > 1) {
            int half = currentLength / 2;
            if (vectorArch.getSupportedVectorArithmeticLength(elementStamp, half, op) != half) {
                break;
            }
            ValueNode left = new SimdCutNode(currentVector, 0, half);
            ValueNode right = new SimdCutNode(currentVector, half, half);
            currentVector = applyOp(isInteger, left, right);
            /* NOTE(review): presumably keeps the stamp current before the next cut; confirm. */
            currentVector.inferStamp();
            currentLength = half;
        }
        ValueNode result = currentVector;

        /* Once we've reached an unsupported length, finish with a linear cascade. */
        if (currentLength > 1) {
            ValueNode acc = new SimdCutNode(currentVector, 0, 1);
            for (int index = 1; index < currentLength; index++) {
                ValueNode value = new SimdCutNode(currentVector, index, 1);
                acc = applyOp(isInteger, acc, value);
            }
            result = acc;
        }

        return reinterpretAsLong(result);
    }

    /**
     * Creates the node applying {@link #op} to {@code x} and {@code y}, using the integer or the
     * floating-point factory depending on {@code isInteger}.
     */
    private ValueNode applyOp(boolean isInteger, ValueNode x, ValueNode y) {
        return isInteger ? BinaryArithmeticNode.binaryIntegerOp(x, y, NodeView.DEFAULT, op) : BinaryArithmeticNode.binaryFloatOp(x, y, NodeView.DEFAULT, op);
    }
}
